{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 继承`Module`构造模型\n",
    "\n",
    "`Module`是`nn`模块提供的一个模型构造类，是所有神经网络结构的基类，可以继承它来自定义我们想要的模型。 通常需要实现`__init__`和`forward`两个函数。\n",
    "\n",
    "在定义的类中无须自定义实现反向传播，系统将通过自动求梯度而自动生成反向传播的`backward`函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "MLP(\n  (hidden): Linear(in_features=784, out_features=256, bias=True)\n  (act): ReLU()\n  (output): Linear(in_features=256, out_features=10, bias=True)\n)\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([[ 0.0880,  0.0951, -0.0702, -0.0717, -0.0217,  0.2864,  0.0481,  0.0537,\n          0.1295, -0.0825],\n        [ 0.3202,  0.2092, -0.0238, -0.0530,  0.0234,  0.2822,  0.1265, -0.1255,\n          0.0528, -0.0278]], grad_fn=<AddmmBackward>)"
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self,**kwargs):\n",
    "        super(MLP,self).__init__(**kwargs)\n",
    "        self.hidden = nn.Linear(784,256)\n",
    "        self.act = nn.ReLU()\n",
    "        self.output = nn.Linear(256,10)\n",
    "\n",
    "    def forward(self,x):\n",
    "        a = self.act(self.hidden(x))\n",
    "        return self.output(a)\n",
    "\n",
    "x = torch.rand(2,784)\n",
    "net = MLP()\n",
    "print(net)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Module`的子类\n",
    "\n",
    "### `Sequential`\n",
    "`Sequential`可以接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加`Module`的实例，而模型的前向计算就是将这些实例按照添加顺序逐一计算。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "MySequential(\n  (0): Linear(in_features=784, out_features=256, bias=True)\n  (1): ReLU()\n  (2): Linear(in_features=256, out_features=10, bias=True)\n)\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([[ 0.1194,  0.0164,  0.1449, -0.0977, -0.1579,  0.0917,  0.0084, -0.1466,\n         -0.1539,  0.0569],\n        [ 0.1211,  0.0190,  0.1228,  0.0579, -0.0767,  0.0438,  0.0250, -0.0359,\n         -0.1090, -0.0024]], grad_fn=<AddmmBackward>)"
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "class MySequential(nn.Module):\n",
    "    from collections import OrderedDict\n",
    "\n",
    "    def __init__(self,*args):\n",
    "        super(MySequential,self).__init__()\n",
    "\n",
    "        if len(args) == 1 and isinstance(args[0],OrderedDict):\n",
    "            for key,module in args.items():\n",
    "                # add_module方法会将module添加到self._modules(一个OrderedDict)\n",
    "                self.add_module(key,module) \n",
    "        else:\n",
    "            for idx,module in enumerate(args):\n",
    "                self.add_module(str(idx),module)\n",
    "\n",
    "    def forward(self,input):\n",
    "        # self._modules返回一个OrderedDict，保证成员按照添加时的顺序遍历\n",
    "        for module in self._modules.values():\n",
    "            input = module(input)\n",
    "\n",
    "        return input\n",
    "\n",
    "net = MySequential(nn.Linear(784,256),nn.ReLU(),nn.Linear(256,10))\n",
    "print(net)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `ModuleList`\n",
    "\n",
    "`ModuleList`接收一个子模块的列表作为输入，可以像`list`那样进行`append`和`extend`操作\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Linear(in_features=256, out_features=10, bias=True)\nModuleList(\n  (0): Linear(in_features=784, out_features=256, bias=True)\n  (1): ReLU()\n  (2): Linear(in_features=256, out_features=10, bias=True)\n)\n"
    }
   ],
   "source": [
    "net = nn.ModuleList([nn.Linear(784,256),nn.ReLU()])\n",
    "net.append(nn.Linear(256,10))\n",
    "print(net[-1])\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `ModuleDict`\n",
    "\n",
    "`ModuleDict`接收一个子模块的字典作为输入，类似字典那样的进行添加访问操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Linear(in_features=784, out_features=256, bias=True)\nLinear(in_features=256, out_features=10, bias=True)\nModuleDict(\n  (act): ReLU()\n  (linear): Linear(in_features=784, out_features=256, bias=True)\n  (output): Linear(in_features=256, out_features=10, bias=True)\n)\n"
    }
   ],
   "source": [
    "net = nn.ModuleDict({'linear':nn.Linear(784,256),'act':nn.ReLU()})\n",
    "net['output'] = nn.Linear(256,10)\n",
    "\n",
    "print(net['linear'])\n",
    "print(net.output)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 复杂的模型\n",
    "使用`Suquential`,`ModuleList`和`ModuleDict`可以构建一些简单的模型，不需要定义`forward`函数。但是直接继承`Module`可以更灵活的扩展模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "FancyMLP(\n  (linear): Linear(in_features=20, out_features=20, bias=True)\n)\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor(-0.4809, grad_fn=<SumBackward0>)"
     },
     "metadata": {},
     "execution_count": 7
    }
   ],
   "source": [
    "class FancyMLP(nn.Module):\n",
    "    def __init__(self,**kwargs):\n",
    "        super(FancyMLP,self).__init__(**kwargs)\n",
    "\n",
    "        self.rand_weight = torch.rand((20,20),requires_grad=False) # 不可训练参数\n",
    "        self.linear = nn.Linear(20,20)\n",
    "\n",
    "    def forward(self,x):\n",
    "        x = self.linear(x)\n",
    "\n",
    "        x = nn.functional.relu(torch.mm(x,self.rand_weight.data) + 1)\n",
    "\n",
    "        # 复用全连接层，等价于两个全连接层共享参数\n",
    "        x = self.linear(x)\n",
    "\n",
    "        # 控制流\n",
    "        while x.norm().item() > 1:\n",
    "            x /= 2\n",
    "        if x.norm().item() < 0.8:\n",
    "            x *= 10\n",
    "\n",
    "        return x.sum()\n",
    "\n",
    "x = torch.rand(2,20)\n",
    "net = FancyMLP()\n",
    "print(net)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([[0.5943, 0.0439, 0.7584, 0.2527, 0.0257, 0.1521, 0.6511, 0.6634, 0.1116,\n         0.8765, 0.8666, 0.0375, 0.6985, 0.9923, 0.6121, 0.1922, 0.7950, 0.6947,\n         0.0807, 0.8435, 0.3812, 0.4148, 0.1212, 0.7269, 0.0686, 0.1769, 0.7371,\n         0.5583, 0.7601, 0.0224, 0.0012, 0.9065, 0.1495, 0.7758, 0.7068, 0.8274,\n         0.8461, 0.0790, 0.3010, 0.7584],\n        [0.3737, 0.8310, 0.8847, 0.5191, 0.9121, 0.3877, 0.5499, 0.6093, 0.5470,\n         0.4420, 0.6045, 0.6672, 0.4715, 0.8269, 0.4019, 0.6695, 0.0603, 0.0871,\n         0.3629, 0.9018, 0.3537, 0.7331, 0.1684, 0.0983, 0.3853, 0.9679, 0.6830,\n         0.7771, 0.9861, 0.1997, 0.1705, 0.2317, 0.3741, 0.8318, 0.2112, 0.5526,\n         0.4948, 0.4162, 0.3590, 0.4691]])\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor(6.7878, grad_fn=<SumBackward0>)"
     },
     "metadata": {},
     "execution_count": 8
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self,**kwargs):\n",
    "        super(NestMLP,self).__init__(**kwargs)\n",
    "        self.net = nn.Sequential(nn.Linear(40,30),nn.ReLU())\n",
    "\n",
    "    def forward(self,x):\n",
    "        return self.net(x)\n",
    "\n",
    "net = nn.Sequential(NestMLP(),nn.Linear(30,20),FancyMLP())\n",
    "\n",
    "x = torch.rand(2,40)\n",
    "print(x)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}